{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code for **\"Activation maximization\"** figure."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can select the net type (`vgg16_caffe`, `vgg19_caffe`, `alexnet_caffe`) and a layer. For your reference, the layer names for each network type are shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vgg_19_names=['conv1_1','relu1_1','conv1_2','relu1_2','pool1',\n",
    "              'conv2_1','relu2_1','conv2_2','relu2_2','pool2',\n",
    "              'conv3_1','relu3_1','conv3_2','relu3_2','conv3_3','relu3_3','conv3_4','relu3_4','pool3',\n",
    "              'conv4_1','relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','conv4_4','relu4_4','pool4',\n",
    "              'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','conv5_4','relu5_4','pool5',\n",
    "              'torch_view','fc6','relu6','drop6','fc7','relu7','drop7','fc8']\n",
    "\n",
    "vgg_16_names = ['conv1_1','relu1_1','conv1_2','relu1_2','pool1',\n",
    "                'conv2_1','relu2_1','conv2_2','relu2_2','pool2',\n",
    "                'conv3_1','relu3_1','conv3_2','relu3_2','conv3_3','relu3_3','pool3',\n",
    "                'conv4_1','relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','pool4',\n",
    "                'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','pool5',\n",
    "                'torch_view','fc6','relu6','drop6','fc7','relu7','fc8']\n",
    "\n",
    "alexnet_names = ['conv1', 'relu1', 'norm1', 'pool1',\n",
    "                 'conv2', 'relu2', 'norm2', 'pool2',\n",
    "                 'conv3', 'relu3', 'conv4', 'relu4',\n",
    "                 'conv5', 'relu5', 'pool5', 'torch_view',\n",
    "                 'fc6', 'relu6', 'drop6',\n",
    "                 'fc7', 'relu7', 'drop7',\n",
    "                 'fc8', 'softmax']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The actual code starts here."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import argparse\n",
    "import os\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = '3'\n",
    "\n",
    "import numpy as np\n",
    "from models import *\n",
    "\n",
    "import torch\n",
    "import torch.optim\n",
    "\n",
    "from utils.perceptual_loss.perceptual_loss import *\n",
    "from utils.common_utils import *\n",
    "\n",
    "torch.backends.cudnn.enabled = True\n",
    "torch.backends.cudnn.benchmark =True\n",
    "# Use CUDA tensors when a GPU is available; otherwise fall back to CPU tensors\n",
    "# so that every later `.type(dtype)` call still works on a machine without CUDA.\n",
    "dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n",
    "\n",
    "PLOT = True\n",
    "fname = './data/feature_inversion/building.jpg'\n",
    "\n",
    "# Choose net type\n",
    "pretrained_net = 'alexnet_caffe' \n",
    "assert pretrained_net in ['alexnet_caffe', 'vgg19_caffe', 'vgg16_caffe']\n",
    "\n",
    "# Choose layers\n",
    "layer_to_use = 'conv4'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "with open('data/imagenet1000_clsid_to_human.txt', 'r') as f:\n",
    "    corresp = json.load(f)\n",
    "\n",
    "if layer_to_use != 'fc8':\n",
    "    map_idx = 2 # Choose here\n",
    "else:\n",
    "    # Choose class\n",
    "    name = 'black swan'\n",
    "    # name = 'cheeseburger'\n",
    "\n",
    "    # Key (as int) of the first entry whose text contains `name`; None when nothing matches.\n",
    "    map_idx = next((int(k) for k, v in corresp.items() if name in v), None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup pretrained net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target imsize \n",
    "imsize = 227 if pretrained_net == 'alexnet_caffe' else 224\n",
    "\n",
    "# Something divisible by a power of two\n",
    "imsize_net = 256\n",
    "\n",
    "# VGG and Alexnet need input to be correctly normalized\n",
    "preprocess, deprocess = get_preprocessor(imsize), get_deprocessor()\n",
    "\n",
    "\n",
    "img_content_pil, img_content_np  = get_image(fname, -1)\n",
    "img_content_prerocessed = preprocess(img_content_pil)[None,:].type(dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "opt_content = {'layers': [layer_to_use], 'what':'features', 'map_idx': map_idx}\n",
    "\n",
    "cnn = get_pretrained_net(pretrained_net).type(dtype)\n",
    "# Pass `dim` explicitly: omitting it is deprecated and emits a warning.\n",
    "# dim=1 is what the implicit rule selected for 2-D and 4-D inputs\n",
    "# (assumes batch-first activations -- TODO confirm for other input ranks).\n",
    "cnn.add_module('softmax', nn.Softmax(dim=1))\n",
    "\n",
    "# Remove the layers we don't need: everything after the deepest requested layer\n",
    "keys = list(cnn._modules.keys())\n",
    "max_idx = max(keys.index(x) for x in opt_content['layers'])\n",
    "for k in keys[max_idx+1:]:\n",
    "    cnn._modules.pop(k)\n",
    "\n",
    "print(cnn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "matcher_content = get_matcher(cnn, opt_content)\n",
    "# Both branches below run with the matcher in 'match' mode\n",
    "matcher_content.mode = 'match'\n",
    "\n",
    "if layer_to_use != 'fc8':\n",
    "    # Choose here\n",
    "    # Window size controls the width of the region where the activations are maximized\n",
    "    matcher_content.window_size = 20 # if = 1 then it is neuron maximization\n",
    "    matcher_content.method = 'maximize'\n",
    "    LR = 0.001\n",
    "else:\n",
    "    LR = 0.01"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup matcher and net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT = 'noise'\n",
    "input_depth = 32\n",
    "OPTIMIZER = 'adam'\n",
    "net_input = get_noise(input_depth, INPUT, imsize_net).type(dtype).detach()\n",
    "OPT_OVER = 'net' #'net,input'\n",
    "pad='reflection'\n",
    "\n",
    "tv_weight=0.0\n",
    "reg_noise_std = 0.03\n",
    "param_noise = True\n",
    "num_iter = 3100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = skip(input_depth, 3, num_channels_down = [16, 32, 64, 128, 128, 128],\n",
    "                           num_channels_up =   [16, 32, 64, 128, 128, 128],\n",
    "                           num_channels_skip = [0, 4, 4, 4, 4, 4],   \n",
    "                           filter_size_down = [5, 3, 5, 5, 3, 5], filter_size_up = [5, 3, 5, 3, 5, 3], \n",
    "                           upsample_mode='bilinear', downsample_mode='avg',\n",
    "                           need_sigmoid=True, pad=pad, act_fun='LeakyReLU').type(dtype)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Compute number of parameters\n",
    "s  = sum(np.prod(list(p.size())) for p in net.parameters())\n",
    "print ('Number of params: %d' % s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net(net_input).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TV"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Uncomment this section if you want to optimize over pixels with TV prior only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# INPUT = 'noise'\n",
    "# input_depth = 3\n",
    "# net_input = (get_noise(input_depth, INPUT, imsize_net).type(dtype)+0.5).detach()\n",
    "\n",
    "# OPT_OVER = 'input' #'net,input'\n",
    "# net = nn.Sequential()\n",
    "# reg_noise_std =0.0\n",
    "# OPTIMIZER = 'adam'# 'LBFGS'\n",
    "# LR = 0.01\n",
    "# tv_weight=1e-6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Radial mask: 0 within distance 100 of the image centre, rising as 1 - 100/d beyond it.\n",
    "# Computed with array operations instead of a per-pixel loop, which also avoids the\n",
    "# division by zero the loop hit at the centre pixel (d == 0).\n",
    "offsets = np.arange(imsize) - imsize // 2\n",
    "dist = np.sqrt(offsets[:, None]**2 + offsets[None, :]**2)\n",
    "falloff = 1 - 100. / np.maximum(dist, 100.)  # equals 1 - min(100/d, 1)\n",
    "\n",
    "# Same batch size, channel count (at most 3) and tensor type as the net_input slice used before\n",
    "n_batch, n_chan = net_input[:, :3].shape[:2]\n",
    "mask = torch.from_numpy(falloff).type_as(net_input)[None, None].repeat(n_batch, n_chan, 1, 1)\n",
    "\n",
    "plot_image_grid([torch_to_np(mask)]);\n",
    "use_mask = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.sr_utils import tv_loss\n",
    "\n",
    "net_input_saved = net_input.detach().clone()\n",
    "noise = net_input.detach().clone()\n",
    "\n",
    "\n",
    "outs = [] \n",
    "\n",
    "def closure():\n",
    "    \"\"\"Evaluate the loss for one optimization step and backpropagate it.\n",
    "\n",
    "    Reads the notebook globals (net, cnn, matcher_content, mask and the\n",
    "    hyper-parameters), rebinds the global `net_input`, increments the global\n",
    "    iteration counter `i`, and every 100 iterations (when PLOT is set) plots\n",
    "    the current output and appends it to `outs`.\n",
    "\n",
    "    Returns:\n",
    "        The total loss tensor, after `backward()` has been called on it.\n",
    "    \"\"\"\n",
    "    global i, net_input\n",
    "\n",
    "    if param_noise:\n",
    "        # Perturb the 4-D weight tensors in place. The previous `n = n + ...` only\n",
    "        # rebound the loop variable, so the parameters were never actually changed.\n",
    "        with torch.no_grad():\n",
    "            for n in [x for x in net.parameters() if len(x.size()) == 4]:\n",
    "                n.add_(n.detach().clone().normal_() * n.std() / 50)\n",
    "\n",
    "    # Start from the saved input each step, optionally with fresh additive noise\n",
    "    net_input = net_input_saved\n",
    "    if reg_noise_std > 0:\n",
    "        net_input = net_input_saved + (noise.normal_() * reg_noise_std)\n",
    "\n",
    "    # Crop the generated image to the size the pretrained net expects\n",
    "    out = net(net_input)[:, :, :imsize, :imsize]\n",
    "\n",
    "    # The forward pass is run for its side effect: the matcher's losses are read right after it\n",
    "    cnn(vgg_preprocess_caffe(out))\n",
    "    total_loss =  sum(matcher_content.losses.values()) * 5\n",
    "\n",
    "    if tv_weight > 0:\n",
    "        total_loss += tv_weight * tv_loss(vgg_preprocess_caffe(out), beta=2)\n",
    "\n",
    "    if use_mask:\n",
    "        # Penalize image content where the mask is non-zero\n",
    "        total_loss += nn.functional.mse_loss(out * mask, mask * 0, size_average=False) * 1e1\n",
    "\n",
    "    total_loss.backward()\n",
    "\n",
    "    print ('Iteration %05d    Loss %.3f' % (i, total_loss.item()), '\\r', end='')\n",
    "    if PLOT and  i % 100==0:\n",
    "        out_np = np.clip(torch_to_np(out), 0, 1)\n",
    "        plot_image_grid([out_np], 3, 3, interpolation='lanczos');\n",
    "\n",
    "        outs.append(out_np)\n",
    "    i += 1\n",
    "\n",
    "    return total_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i=0\n",
    "\n",
    "p = get_params(OPT_OVER, net, net_input)\n",
    "\n",
    "optimize(OPTIMIZER, p, closure, LR, num_iter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = net(net_input)[:, :, :imsize, :imsize]\n",
    "plot_image_grid([torch_to_np(out)], 3, 3);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
